package org.databandtech.flinkstreaming;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
import org.databandtech.flinkstreaming.Entity.ItemCount;

/**
 * Ranks the per-window item counts and emits the Top N items of each window.
 *
 * <p>The stream is expected to be keyed by the window end timestamp ({@code windowEnd}), so the
 * key type is {@code Long}. Each input element is an {@link ItemCount}; each output element is a
 * list of at most {@code topNSize} items sorted by count in descending order.
 *
 * <p>Incoming elements are buffered in keyed {@link ListState}. Every element registers an
 * event-time timer at {@code windowEnd + timerDelay}; when it fires, the buffered items are
 * ranked, the state is cleared, and the Top N list is emitted.
 */
public class KafkaTopNKeyedFunction extends KeyedProcessFunction<Long, ItemCount, List<ItemCount>> {

    private static final long serialVersionUID = -4057567950417032227L;

    /** Delay after the window end used by the single-argument constructor (the original value). */
    private static final long DEFAULT_TIMER_DELAY = 10000L;

    /** Maximum number of items emitted per window; a value <= 0 yields an empty list. */
    private final int topNSize;

    /**
     * Offset added to {@code windowEnd} when registering the ranking timer, in the same time unit
     * as {@code ItemCount#getWindowEnd()} (presumably epoch milliseconds — confirm against caller).
     */
    private final long timerDelay;

    // Buffers all items of the current key (window) until the timer fires and Top N is computed.
    // transient: obtained from the runtime context in open(), never part of the serialized function.
    private transient ListState<ItemCount> itemState;

    /**
     * Creates a Top N function that fires {@value #DEFAULT_TIMER_DELAY} time units after the
     * window end.
     *
     * @param topNSize maximum number of items to emit per window
     */
    public KafkaTopNKeyedFunction(int topNSize) {
        this(topNSize, DEFAULT_TIMER_DELAY);
    }

    /**
     * Creates a Top N function with a configurable timer delay.
     *
     * @param topNSize maximum number of items to emit per window
     * @param timerDelay offset added to {@code windowEnd} when registering the event-time timer
     * @throws IllegalArgumentException if {@code timerDelay} is negative
     */
    public KafkaTopNKeyedFunction(int topNSize, long timerDelay) {
        if (timerDelay < 0) {
            throw new IllegalArgumentException("timerDelay must be >= 0, was " + timerDelay);
        }
        this.topNSize = topNSize;
        this.timerDelay = timerDelay;
    }

    /** Registers the keyed list state that buffers the items of each window. */
    @Override
    public void open(Configuration parameters) throws Exception {
        ListStateDescriptor<ItemCount> itemViewStateDesc =
                new ListStateDescriptor<>("itemstate", ItemCount.class);
        itemState = getRuntimeContext().getListState(itemViewStateDesc);
    }

    /**
     * Buffers {@code input} in state and (re)registers the ranking timer for its window.
     * Nothing is emitted here; output happens in {@link #onTimer}.
     */
    @Override
    public void processElement(ItemCount input,
            KeyedProcessFunction<Long, ItemCount, List<ItemCount>>.Context context,
            Collector<List<ItemCount>> collector) throws Exception {
        System.out.println("TOP input processElement:" + input);
        itemState.add(input);
        // Every element of the same window computes the same timestamp, so all of them target one
        // timer; the delay presumably leaves room for the window's items to arrive — confirm.
        context.timerService().registerEventTimeTimer(input.getWindowEnd() + timerDelay);
    }

    /**
     * Ranks all buffered items of the current window by count (descending), clears the state and
     * emits at most {@code topNSize} of them. Nothing is emitted when no items were buffered.
     */
    @Override
    public void onTimer(long timestamp, OnTimerContext ctx, Collector<List<ItemCount>> out) throws Exception {
        System.out.println("TOP onTimer:" + timestamp);

        // Collect everything buffered for this window, then release the state.
        List<ItemCount> allItems = new ArrayList<>();
        Iterable<ItemCount> buffered = itemState.get();
        if (buffered != null) {
            for (ItemCount item : buffered) {
                allItems.add(item);
            }
        }
        itemState.clear();

        if (allItems.isEmpty()) {
            return;
        }

        // Compare counts directly instead of subtracting and narrowing to int, which can overflow
        // and produce a wrong (and contract-violating) ordering for large counts.
        allItems.sort(Comparator.comparing(ItemCount::getCount, Comparator.reverseOrder()));

        // Never read past the end of the list when a window has fewer than topNSize items.
        int limit = Math.max(0, Math.min(topNSize, allItems.size()));
        out.collect(new ArrayList<>(allItems.subList(0, limit)));
    }
}
